
import json
import os
import os.path as op
import sys
from utils.extract_features import (convert_coords_npz,
                                    generate_geo_index_wrapper,
                                    extract_features,
                                    merge_data)

class DataGenFlow(object):
    """
    Drive the four-stage training-data generation pipeline.

    One "route" is a (date, gf5_id) pair taken from data.cfg. Every route
    is pushed through the stages below in order, and a route stops at the
    first stage that reports errors.

    stage1: ConvertCoords <- input(['coords_txt'])
              |
              V
              output(['coords_npz'])---+
                                       |
                                       V
    stage2: GenerateIndex <- input(['coords_npz', 'depth_txt'])
              |
              V
              output(['inds_npz'])---+
                                     |
                                     V
    stage3: ExtractFeat <- input(['inds_npz', 'geotiff'])
              |
              V
              output(['feat_label_npz', 'label_png'])---+
                                                        |
                                                        V
    stage4: MergeData <- input(['feat_label_npz'])
              |
              V
              output(['train_npz', 'val_npz'])

    Intermediate files go to <DATA_BASE>/train/.temp; train.npz and
    val.npz go to <DATA_BASE>/train.
    """
    # Root of the data tree, resolved relative to this source file.
    DATA_BASE = op.abspath(op.join(op.dirname(__file__),
                                   '../../Data'))
    # JSON configuration with the "depth_data" and "geo_data" sections.
    DATA_CFG = op.join(op.dirname(__file__),
                       "data.cfg")

    def __init__(self, dates, patch_size=2):
        """
        Args:
            dates: iterable of date keys looked up in data.cfg "depth_data".
            patch_size: patch size passed to the geo-index stage and
                embedded in intermediate file names.

        Exits the process with status 1 if data.cfg is missing or invalid.
        """
        self.dates = dates
        self.patch_size = patch_size
        self._load_data_cfg()
        self.dump_dir = op.join(DataGenFlow.DATA_BASE, "train")
        # Build from dump_dir instead of hard-coding a "/" separator.
        self.temp_dir = op.join(self.dump_dir, ".temp")

    def build_data(self):
        """
        Build the dataset: prepare output directories, enumerate routes,
        and run every route through the pipeline.
        """
        self._build_dir()
        self._routes_generator()
        for route in self.routes:
            self._one_data_gen_route(route)

    def _build_dir(self):
        """
        Ensure the output and temporary directories exist.

        Exits the process with status 1 if a directory cannot be created.
        """
        for path in (self.dump_dir, self.temp_dir):
            if op.isdir(path):
                continue
            try:
                os.makedirs(path)
            except OSError:
                # Another process may have created it in the meantime;
                # only a directory that is still absent is a real failure.
                if not op.isdir(path):
                    print("Error: create directory failed!")
                    sys.exit(1)

    def _routes_generator(self):
        """
        Generate data routes; one route is a (date, gf5_id) pair.

        Dates that are absent from "depth_data", or whose "geo_datas" list
        is missing, empty or null, produce no routes.
        """
        depth_data = self.cfg_data.get("depth_data", {})
        self.routes = [(date, gf5_id)
                       for date in self.dates
                       if date in depth_data
                       for gf5_id in (depth_data[date].get("geo_datas")
                                      or [])]

    def _load_data_cfg(self):
        """
        Load data.cfg (JSON) into self.cfg_data.

        Exits the process with status 1 if the file is missing or is not
        valid JSON.
        """
        if not os.path.exists(DataGenFlow.DATA_CFG):
            print("Error: data.cfg is missing.")
            sys.exit(1)
        try:
            with open(DataGenFlow.DATA_CFG, 'r') as f:
                self.cfg_data = json.load(f)
        except ValueError as err:
            # json raises ValueError (JSONDecodeError on py3) on bad input.
            print("Error: data.cfg is not valid JSON: {}".format(err))
            sys.exit(1)

    def _one_data_gen_route(self, route):
        """
        Run one data generation route through its 4 stages:
          1. Convert coords.txt to coords.npz
          2. Generate geo index
          3. Extract features from the geotiff file
          4. Merge data into the shared train/val files

        Path or configuration errors, or a stage returning a non-empty
        list of error messages, are printed and stop this route only.
        """
        errors = self._gen_file_path(route)
        if errors:
            print("\n".join(errors))
            return
        stages = (self._stage_convert_coords_npz,
                  self._stage_generate_geo_index,
                  self._stage_extract_features,
                  self._stage_merge_data)
        for stage in stages:
            errors = stage()
            if errors:
                print("\n".join(errors))
                return

    def _gen_file_path(self, route):
        """
        Resolve input paths and derive intermediate/output paths for a route.

        Sets the self.curr_* attributes consumed by the stage methods.

        Args:
            route: (date, gf5_id) tuple.

        Returns:
            A list of error messages. It is empty when every input is
            present. On a configuration error the list is returned early
            and the path attributes are not set.
        """
        date, gf5_id = route
        depth_cfg = self.cfg_data.get("depth_data", {}).get(date, {})
        geo_cfg = self.cfg_data.get("geo_data", {}).get(gf5_id)
        # Report configuration gaps instead of aborting the whole build.
        if geo_cfg is None:
            return ["geo_data entry '{}' is missing from data.cfg."
                    .format(gf5_id)]
        if "depth_txt" not in depth_cfg:
            return ["depth_txt for date '{}' is missing from data.cfg."
                    .format(date)]

        errors = list()

        self.curr_depth_txt = op.join(DataGenFlow.DATA_BASE,
                                      depth_cfg["depth_txt"])
        if not op.exists(self.curr_depth_txt):
            errors.append("{} is missing.".format(self.curr_depth_txt))

        self.curr_geotiff = op.join(DataGenFlow.DATA_BASE,
                                    geo_cfg["geotiff"])
        if not op.exists(self.curr_geotiff):
            errors.append("{} is missing.".format(self.curr_geotiff))

        self.curr_coords_txt = op.join(DataGenFlow.DATA_BASE,
                                       geo_cfg["coords_txt"])
        if not op.exists(self.curr_coords_txt):
            errors.append("{} is missing.".format(self.curr_coords_txt))

        # Patch-size-dependent intermediates share this file-name prefix.
        patch_prefix = "{}.patch_{}".format(gf5_id, self.patch_size)
        self.curr_coords_npz = op.join(self.temp_dir,
                                       gf5_id + ".coords.npz")
        self.curr_inds_npz = op.join(self.temp_dir,
                                     patch_prefix + ".inds.npz")
        self.curr_feat_label_npz = op.join(self.temp_dir,
                                           patch_prefix + ".feat_train.npz")
        self.curr_label_png = op.join(self.temp_dir,
                                      patch_prefix + ".label.png")

        self.curr_mask = geo_cfg.get("mask", [])
        self.curr_proportion = geo_cfg.get("proportion", 1)
        # Every route targets the same train/val files.
        self.curr_train_npz = op.join(self.dump_dir, "train.npz")
        self.curr_val_npz = op.join(self.dump_dir, "val.npz")

        self.curr_route = route
        return errors

    def _stage_convert_coords_npz(self):
        """
        Stage 1: convert coords.txt to coords.npz.
        """
        self.curr_input_files = {"coords_txt": self.curr_coords_txt}
        self.curr_out_files = {"coords_npz": self.curr_coords_npz}
        convert_coords_npz(data_route=self.curr_route,
                           input_files=self.curr_input_files,
                           out_files=self.curr_out_files)

    def _stage_generate_geo_index(self):
        """
        Stage 2: generate geo index data from coords.npz and depth.txt.
        """
        self.curr_input_files = {"depth_txt": self.curr_depth_txt,
                                 "coords_npz": self.curr_coords_npz}
        self.curr_out_files = {"inds_npz": self.curr_inds_npz}
        generate_geo_index_wrapper(data_route=self.curr_route,
                                   mask=self.curr_mask,
                                   input_files=self.curr_input_files,
                                   out_files=self.curr_out_files,
                                   patch_size=self.patch_size)

    def _stage_extract_features(self):
        """
        Stage 3: extract pixel values from the geotiff file.
        """
        self.curr_input_files = {"inds_npz": self.curr_inds_npz,
                                 "geotiff": self.curr_geotiff}
        self.curr_out_files = {"feat_label_npz": self.curr_feat_label_npz,
                               "label_png": self.curr_label_png}
        extract_features(data_route=self.curr_route,
                         mask=self.curr_mask,
                         input_files=self.curr_input_files,
                         out_files=self.curr_out_files)

    def _stage_merge_data(self):
        """
        Stage 4: merge this route's features into train.npz / val.npz.
        """
        self.curr_input_files = {"feat_label_npz": self.curr_feat_label_npz}
        self.curr_out_files = {"train_npz": self.curr_train_npz,
                               "val_npz": self.curr_val_npz}
        merge_data(data_route=self.curr_route,
                   input_files=self.curr_input_files,
                   proportion=self.curr_proportion,
                   out_files=self.curr_out_files)

if __name__ == "__main__":
    # Dates may be given on the command line; with no arguments the
    # previously hard-coded default date is used.
    cli_dates = sys.argv[1:] or ["2019_03_27"]
    flow = DataGenFlow(cli_dates)
    flow.build_data()
